from yaml import emit
from ppdet.modeling.backbones import swin_transformer, res2net
from ppdet.utils.checkpoint_v1 import *


# Backbone instance whose state-dict keys serve as the target parameter names
# for the conversion below. Width 26 / 4 scales / variant 'd' correspond to
# the "Res2Net50_vd_26w_4s" checkpoint named in weights_path.
backbone_cfg = {
    'depth': 50,
    'width': 26,
    'scales': 4,
    'variant': 'd',
    'lr_mult_list': [1.0, 1.0, 1.0, 1.0],
    'groups': 1,
    'norm_type': 'bn',
    'norm_decay': 0.,
    'freeze_norm': True,
    'freeze_at': 0,
    'return_idx': [0, 1, 2, 3],
    'dcn_v2_stages': [-1],
    'num_stages': 4,
}
model = res2net.Res2Net(**backbone_cfg)

# Input checkpoint to convert. This is a machine-specific absolute path and
# has to be edited before the script is run anywhere else.
weights_path = r'C:\Users\Mrtutu\Desktop\fsdownload\Res2Net50_vd_26w_4s_ssld_pretrained.pdparams'
# Output file stem; the '.pdparams' extension is appended at save time.
save_path = 'Res2Net50_vd_26w_4s_ssld_pretrained_pretrained_fixname'

# Source parameters (read from disk) and destination parameters (taken from
# the freshly constructed model). The two loads are independent of each other.
param_state_dict = paddle.load(weights_path)
model_dict = model.state_dict()


# Paddleclas --> PPdet backbone
#
# Copies checkpoint tensors into model_dict under the model's own key names.
# Keys whose third dotted component is '0' (the first block of a stage) are
# translated by name; every other key takes the checkpoint entry found at the
# same position, so those entries rely on both dicts listing their parameters
# in the same order.

# Checkpoint sub-layer name used for each single-conv branch of a first block.
_BLOCK0_BRANCH_LAYERS = {
    'branch1': 'short',
    'branch2a': 'conv0',
    'branch2c': 'conv2',
}


def _block0_source_key(names):
    """Return the checkpoint key matching a first-block model key, or None.

    Args:
        names: the model key split on '.', e.g.
            ['res2', ..., '0', 'branch2a', 'conv', 'weight'] (the second
            component is not inspected). The last character of names[0] must
            be the stage digit and names[3] the branch name.

    Returns:
        The checkpoint key string, or None when the branch or the layer kind
        is not one this script knows how to translate.

    Raises:
        ValueError: if the stage digit or the branch2b index is not an integer.
    """
    stage_num = int(names[0][-1])
    branch = names[3]

    if branch in _BLOCK0_BRANCH_LAYERS:
        layer = _BLOCK0_BRANCH_LAYERS[branch]
    elif branch == 'branch2b':
        # branch2b holds several convs; the checkpoint numbers them from 1.
        layer = 'res%da_branch2b_%d' % (stage_num, int(names[4]) + 1)
    else:
        return None

    # The layer kind is read from the component right before the parameter
    # name. For flat keys (branch.conv.weight / branch2b.N.conv.weight) this is
    # the same component a fixed index would select, but it also stays correct
    # when the conv/norm pair sits one level deeper.
    # NOTE(review): presumably the shortcut of a strided 'd'-variant block is
    # nested like that (branch1.conv.norm.*) -- confirm against the Res2Net
    # implementation in ppdet.
    kind = names[-2]
    if kind == 'conv':
        suffix = '._conv.weight'
    elif kind == 'norm':
        suffix = '._batch_norm.' + names[-1]
    else:
        return None

    # Checkpoint stages are numbered from 0 while model stages start at 2.
    return 'bb_%d_0.%s%s' % (stage_num - 2, layer, suffix)


# zip() stops at the shorter dict; if that is the checkpoint, the remaining
# model parameters would silently keep their initial values.
if len(param_state_dict) < len(model_dict):
    print('WARNING: checkpoint has %d entries but the model has %d; '
          'trailing model parameters will not be converted'
          % (len(param_state_dict), len(model_dict)))

for k1, k2 in zip(model_dict.keys(), param_state_dict.keys()):
    names = k1.split('.')
    if len(names) > 2 and names[2] == '0':  # blocks 0
        new_key = _block0_source_key(names)
        if new_key is None:
            # Previously such keys were skipped without any trace (or crashed
            # on an undefined name); make the omission visible instead.
            print('skip (no mapping rule): %s' % k1)
            continue
        if new_key not in param_state_dict:
            raise KeyError('checkpoint has no entry %r (wanted for model key %r)'
                           % (new_key, k1))
        print('%s: %s --> %s' % (names[3], k1, new_key))
    else:
        new_key = k2
        print(k1, ' ---> ', k2)

    # Record the destination shape BEFORE overwriting the entry; reading it
    # afterwards would compare the checkpoint tensor with itself.
    target_shape = list(model_dict[k1].shape)
    source_shape = list(param_state_dict[new_key].shape)
    model_dict[k1] = param_state_dict[new_key]
    print(target_shape, '  ', source_shape)
    if target_shape != source_shape:
        print('WARNING: shape mismatch for %s (model %s, checkpoint %s)'
              % (k1, target_shape, source_shape))


# Re-key every converted parameter under a 'backbone.' prefix, presumably so
# the names line up with a detector that holds this network as its backbone.
new_model_dict = {
    'backbone.' + param_name: tensor
    for param_name, tensor in model_dict.items()
}

print('save model...')

# Written next to the working directory as <save_path>.pdparams.
paddle.save(new_model_dict, save_path + ".pdparams")


